# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Serving Client"""

import logging

import grpc
import numpy as np
import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc


def _create_tensor(data, tensor=None):
    """Create tensor from numpy data.

    Fills ``tensor`` (or a new ``ms_service_pb2.Tensor`` when it is None) with the
    shape, the MindSpore dtype enum and the raw bytes of ``data``.

    Args:
        data: numpy ndarray or numpy scalar; it must expose ``shape``, ``dtype``
            and ``tobytes()``.
        tensor: optional protobuf Tensor to fill in place.

    Returns:
        The filled protobuf tensor.

    Raises:
        RuntimeError: if ``data.dtype`` is not one of the supported bool/int/float dtypes.
    """
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    # np.bool_ is used instead of the np.bool alias, which NumPy 1.24 removed.
    dtype_map = {
        np.bool_: ms_service_pb2.MS_BOOL,
        np.int8: ms_service_pb2.MS_INT8,
        np.uint8: ms_service_pb2.MS_UINT8,
        np.int16: ms_service_pb2.MS_INT16,
        np.uint16: ms_service_pb2.MS_UINT16,
        np.int32: ms_service_pb2.MS_INT32,
        np.uint32: ms_service_pb2.MS_UINT32,

        np.int64: ms_service_pb2.MS_INT64,
        np.uint64: ms_service_pb2.MS_UINT64,
        np.float16: ms_service_pb2.MS_FLOAT16,
        np.float32: ms_service_pb2.MS_FLOAT32,
        np.float64: ms_service_pb2.MS_FLOAT64,
    }
    ms_dtype = None
    for np_type, candidate in dtype_map.items():
        # `==` between a numpy scalar type and a dtype compares them as dtypes.
        if np_type == data.dtype:
            ms_dtype = candidate
            break
    # Resolve the dtype before touching the tensor so an unsupported input
    # does not leave it half-filled.
    if ms_dtype is None:
        raise RuntimeError("Unknown data type " + str(data.dtype))

    tensor.shape.dims.extend(data.shape)
    tensor.dtype = ms_dtype
    tensor.data = data.tobytes()
    return tensor


def _create_scalar_tensor(vals, tensor=None):
    """Create tensor from Python scalar data.

    A lone scalar is wrapped in a one-element tuple first, so it becomes an array
    of shape (1,). A tuple or list is handed to ``np.array`` unchanged. The numpy
    dtype is whatever ``np.array`` infers, and the conversion to a protobuf tensor
    is delegated to ``_create_tensor``.
    """
    packed = vals if isinstance(vals, (tuple, list)) else (vals,)
    return _create_tensor(np.array(packed), tensor)


def _create_bytes_tensor(bytes_vals, tensor=None):
    """Create tensor from bytes data.

    The result is a 1-D MS_BYTES tensor with one ``bytes_val`` entry per item.
    A single non-sequence value is treated as a one-element sequence.
    """
    target = ms_service_pb2.Tensor() if tensor is None else tensor
    blobs = bytes_vals if isinstance(bytes_vals, (tuple, list)) else (bytes_vals,)

    target.shape.dims.append(len(blobs))
    target.dtype = ms_service_pb2.MS_BYTES
    for blob in blobs:
        target.bytes_val.append(blob)
    return target


def _create_str_tensor(str_vals, tensor=None):
    """Create tensor from str data.

    Each string is UTF-8 encoded into ``bytes_val``. The result is a 1-D
    MS_STRING tensor whose length equals the number of strings. A single
    non-sequence value is treated as a one-element sequence.
    """
    target = ms_service_pb2.Tensor() if tensor is None else tensor
    texts = str_vals if isinstance(str_vals, (tuple, list)) else (str_vals,)

    target.shape.dims.append(len(texts))
    target.dtype = ms_service_pb2.MS_STRING
    for text in texts:
        target.bytes_val.append(bytes(text, encoding="utf8"))
    return target


def _create_numpy_from_tensor(tensor):
    """Create numpy (or str/bytes) data from a protobuf tensor.

    MS_STRING and MS_BYTES tensors are read from ``bytes_val``. MS_STRING items
    are decoded with ``bytes.decode``; MS_BYTES items are returned as-is. A list
    holding exactly one element is unwrapped to that element.

    Every other dtype is reinterpreted from ``tensor.data`` with
    ``np.frombuffer`` and reshaped to ``tensor.shape.dims``.

    Raises:
        RuntimeError: if ``tensor.dtype`` has no numpy equivalent.
    """
    if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
        if tensor.dtype == ms_service_pb2.MS_STRING:
            result = [bytes.decode(item) for item in tensor.bytes_val]
        else:
            result = list(tensor.bytes_val)
        if len(result) == 1:
            return result[0]
        return result

    # MS_INT16 previously mapped to the enum value itself instead of np.int16,
    # and np.bool (removed in NumPy 1.24) is replaced by np.bool_.
    dtype_map = {
        ms_service_pb2.MS_BOOL: np.bool_,
        ms_service_pb2.MS_INT8: np.int8,
        ms_service_pb2.MS_UINT8: np.uint8,
        ms_service_pb2.MS_INT16: np.int16,
        ms_service_pb2.MS_UINT16: np.uint16,
        ms_service_pb2.MS_INT32: np.int32,
        ms_service_pb2.MS_UINT32: np.uint32,

        ms_service_pb2.MS_INT64: np.int64,
        ms_service_pb2.MS_UINT64: np.uint64,
        ms_service_pb2.MS_FLOAT16: np.float16,
        ms_service_pb2.MS_FLOAT32: np.float32,
        ms_service_pb2.MS_FLOAT64: np.float64,
    }
    np_dtype = dtype_map.get(tensor.dtype)
    if np_dtype is None:
        raise RuntimeError("Unknown data type " + str(tensor.dtype))
    return np.frombuffer(tensor.data, np_dtype).reshape(tuple(tensor.shape.dims))


def _check_str(arg_name, str_val):
    """Check whether the input parameters are reasonable str input"""
    if not isinstance(str_val, str):
        raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
    if not str_val:
        raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")


def _check_int(arg_name, int_val, mininum=None, maximum=None):
    """Check whether the input parameters are reasonable int input"""
    if not isinstance(int_val, int):
        raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
    if mininum is not None and int_val < mininum:
        if maximum is not None:
            raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
        raise RuntimeError(f"Parameter '{arg_name}' should be >= {mininum}")
    if maximum is not None and int_val > maximum:
        if mininum is not None:
            raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
        raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")


class Client:
    """
    The Client encapsulates the serving gRPC API, which can be used to create requests,
    access serving, and parse results.

    Args:
        ip (str): Serving ip.
        port (int): Serving port.
        servable_name (str): The name of servable supplied by Serving.
        method_name (str): The name of method supplied by servable.
        version_number (int): The version number of servable, default 0,
            which means the maximum version number in all running versions.
    Raises:
        RuntimeError: The type or value of the parameters is invalid, or other errors happened.

    Examples:
        >>> from mindspore_serving.client import Client
        >>> import numpy as np
        >>> client = Client("localhost", 5500, "add", "add_cast")
        >>> instances = []
        >>> x1 = np.ones((2, 2), np.int32)
        >>> x2 = np.ones((2, 2), np.int32)
        >>> instances.append({"x1": x1, "x2": x2})
        >>> result = client.infer(instances)
        >>> print(result)
    """

    def __init__(self, ip, port, servable_name, method_name, version_number=0):
        _check_str("ip", ip)
        _check_int("port", port, 1, 65535)
        _check_str("servable_name", servable_name)
        _check_str("method_name", method_name)
        _check_int("version_number", version_number, 0)

        self.ip = ip
        self.port = port
        self.servable_name = servable_name
        self.method_name = method_name
        self.version_number = version_number

        channel_str = str(ip) + ":" + str(port)
        # Allow messages of up to 512MB in both directions so large tensors fit.
        msg_bytes_size = 512 * 1024 * 1024  # 512MB
        channel = grpc.insecure_channel(channel_str,
                                        options=[
                                            ('grpc.max_send_message_length', msg_bytes_size),
                                            ('grpc.max_receive_message_length', msg_bytes_size),
                                        ])
        self.stub = ms_service_pb2_grpc.MSServiceStub(channel)

    def infer(self, instances):
        """
        Used to create requests, access serving, and parse results.

        Args:
            instances (map, tuple/list of map): Instance or tuple of instances, every instance item is the
                inputs map. The map key is the input name, and the value is the input value.

        Returns:
            The value produced by ``_paser_result`` for the server reply: a list with one
            dict per instance, or ``{"error": msg}`` when the server reports a single error.
            If the gRPC call itself fails, ``{"error": "Grpc Error, <status code value>"}``
            is returned, and the details are written to this module's logger.

        Raises:
            RuntimeError: The type or value of the parameters is invalid, or other errors happened.
        """
        if not isinstance(instances, (tuple, list)):
            instances = (instances,)
        request = self._create_request()
        for item in instances:
            if not isinstance(item, dict):
                raise RuntimeError("instance should be a map")
            request.instances.append(self._create_instance(**item))

        try:
            result = self.stub.Predict(request)
        except grpc.RpcError as e:
            status_code = e.code()
            # Report through logging rather than bare prints to stdout, so library users control the output.
            logging.getLogger(__name__).error("gRPC Predict failed: %s (status %s, value %s)",
                                              e.details(), status_code.name, status_code.value)
            return {"error": "Grpc Error, " + str(status_code.value)}
        return self._paser_result(result)

    def _create_request(self):
        """Used to create request spec: a PredictRequest addressed to this client's servable/method/version."""
        request = ms_service_pb2.PredictRequest()
        request.servable_spec.name = self.servable_name
        request.servable_spec.method_name = self.method_name
        request.servable_spec.version_number = self.version_number
        return request

    def _create_instance(self, **kwargs):
        """Used to create gRPC instance.

        Each keyword becomes an entry of ``instance.items``. The value type selects
        the encoder: numpy arrays/scalars (including numpy bool scalars), str, Python
        bool/int/float, or bytes.

        Raises:
            RuntimeError: for any other value type.
        """
        instance = ms_service_pb2.Instance()
        for k, w in kwargs.items():
            tensor = instance.items[k]
            # np.bool_ is not a subclass of np.number, so it is listed explicitly.
            if isinstance(w, (np.ndarray, np.number, np.bool_)):
                _create_tensor(w, tensor)
            elif isinstance(w, str):
                _create_str_tensor(w, tensor)
            elif isinstance(w, (bool, int, float)):
                _create_scalar_tensor(w, tensor)
            elif isinstance(w, bytes):
                _create_bytes_tensor(w, tensor)
            else:
                raise RuntimeError("Not support value type " + str(type(w)))
        return instance

    def _paser_result(self, result):
        """Used to parse result.

        Returns:
            ``{"error": msg}`` if the reply carries exactly one error message.
            Otherwise, a list with one entry per returned instance: a dict mapping output
            names to decoded values, or ``{"error": msg}`` when that instance's error_code
            is non-zero.

        Raises:
            RuntimeError: if the number of error messages is not 0, 1 or the number of instances.
        """
        error_msg_len = len(result.error_msg)
        # NOTE(review): with a single instance, a per-instance error_msg entry (even one with
        # error_code 0) takes this early return - confirm the server only sends a lone
        # error_msg on failure.
        if error_msg_len == 1:
            return {"error": bytes.decode(result.error_msg[0].error_msg)}
        instance_len = len(result.instances)
        if error_msg_len not in (0, instance_len):
            raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
                               f"length of instances {instance_len}")
        ret_val = []
        for i, instance in enumerate(result.instances):
            if error_msg_len == 0 or result.error_msg[i].error_code == 0:
                ret_val.append({k: _create_numpy_from_tensor(w) for k, w in instance.items.items()})
            else:
                ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
        return ret_val
